#!/usr/bin/env python3
# -*- coding: utf-8 -*-
r"""
run_kernel.py — Vol3 kernel diagnostics (single-file pipeline)

Builds M(n) per context, checks spectral radius vs D(n), and
invertibility/conditioning. Writes CSVs/JSON + optional plots/hashes.

Windows / Anaconda usage
------------------------
conda activate fphs
cd C:\Users\kentn\vol3-kernel-diagnostics

REM Analytic (tri-diagonal, row-stochastic P; M = D·P)
python run_kernel.py --D D_values.csv --mode analytic --dim 5 --pivot_params pivot_params.json --outdir results --plots --hashes

REM Empirical (long-form rates: n,i,j,p; M = D·P)
python run_kernel.py --D D_values.csv --mode empirical --rates results\flip_rates_by_context.csv --outdir results --tol_eig 1e-6 --tol_inv 1e-12 --hashes
"""

from __future__ import annotations
import argparse, json, sys
from pathlib import Path
from datetime import datetime

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt


# ------------------------------- I/O helpers -------------------------------

def read_D_map(path: Path) -> tuple[list[float], list[float]]:
    """
    Load the n -> D(n) table from disk.

    Supported inputs:
      - a file whose suffix is ".csv" (case-insensitive) with columns n, D
      - any other file, parsed as JSON, holding either
          a dict  {"n": D, ...}  or
          a list  [{"n": n, "D": D}, ...]

    Returns two parallel lists (ns, Ds), ordered by increasing n.

    Raises FileNotFoundError when the file does not exist, and ValueError
    when the CSV lacks the n/D columns or the JSON top level is neither a
    dict nor a list.
    """
    if not path.exists():
        raise FileNotFoundError(f"D file not found: {path}")

    def by_n(pair):
        return pair[0]

    if path.suffix.lower() == ".csv":
        table = pd.read_csv(path)
        if not {"n", "D"}.issubset(table.columns):
            raise ValueError(f"{path} must have columns: n,D")
        rows = list(zip(table["n"].astype(float), table["D"].astype(float)))
    else:
        payload = json.loads(path.read_text(encoding="utf-8"))
        if isinstance(payload, dict):
            rows = [(float(key), float(val)) for key, val in payload.items()]
        elif isinstance(payload, list):
            rows = [(float(rec["n"]), float(rec["D"])) for rec in payload]
        else:
            raise ValueError("Unsupported D map format")

    # Stable sort on n only, so ties keep their input order.
    rows.sort(key=by_n)
    ns = [n for n, _ in rows]
    Ds = [d for _, d in rows]
    return ns, Ds


def read_rates_longform(path: Path) -> dict[float, np.ndarray]:
    """
    Read a long-form rates CSV with columns n,i,j,p and build, for every
    context n, a square row-stochastic matrix P with P[i, j] = p.

    Indices i, j are zero-based; the matrix dimension for a context is
    max(i, j) + 1 over that context's rows. Entries not listed stay 0.
    If an (i, j) pair is listed more than once, the last row wins.

    Returns a dict mapping float(n) -> P (2D ndarray).

    Raises:
      FileNotFoundError  if the file does not exist.
      ValueError         if a required column is missing, an index is
                         negative, or a probability is negative.
      SystemExit         if any row of a P does not sum to ~1.
    """
    if not path.exists():
        raise FileNotFoundError(f"Rates file not found: {path}")
    df = pd.read_csv(path)
    req = {"n", "i", "j", "p"}
    if not req <= set(df.columns):
        raise ValueError(f"{path} must have columns: n,i,j,p")

    P_by_n: dict[float, np.ndarray] = {}
    for n, sub in df.groupby("n"):
        n_f = float(n)
        rows = [int(v) for v in sub["i"]]
        cols = [int(v) for v in sub["j"]]
        probs = [float(v) for v in sub["p"]]

        # A negative index would not fail: NumPy would wrap it around to the
        # end of the matrix and silently corrupt another state's row.
        if min(rows) < 0 or min(cols) < 0:
            raise ValueError(
                f"Negative state index for n={n_f} in {path}: indices i,j must be >= 0"
            )
        # Row sums alone cannot catch negative entries (e.g. 1.5 and -0.5).
        if any(p < 0 for p in probs):
            raise ValueError(
                f"Negative probability for n={n_f} in {path}: p must be >= 0"
            )

        dim = max(max(rows), max(cols)) + 1
        P = np.zeros((dim, dim), dtype=float)
        for i, j, p in zip(rows, cols, probs):
            P[i, j] = p

        rowsums = P.sum(axis=1)
        # np.allclose also applies its default rtol=1e-5, so the effective
        # tolerance on each row sum is atol + rtol*1.0, not atol alone.
        if not np.allclose(rowsums, 1.0, atol=1e-12):
            dev = float(np.max(np.abs(rowsums - 1.0)))
            raise SystemExit(f"Row sums not ~1 for n={n_f}: max dev={dev}")
        P_by_n[n_f] = P
    return P_by_n


# --------------------------- Canonical pivot loader -------------------------

def load_pivot_params(ppath: Path) -> tuple[float, float]:
    """
    Read the pivot parameters (a, b) of g(D) = a*D + b from a JSON file.

    If ppath does not exist, "pivot_params.json" in the current working
    directory is tried instead; FileNotFoundError is raised when neither
    exists.

    The parameters must satisfy g(2) = 1 within 1e-6 and a < 0; otherwise
    ValueError is raised. Returns the pair (a, b).
    """
    source = ppath
    if not source.exists():
        fallback = Path("pivot_params.json")
        if not fallback.exists():
            raise FileNotFoundError(f"pivot_params.json not found: {ppath}")
        source = fallback

    params = json.loads(source.read_text(encoding="utf-8"))
    a = float(params["a"])
    b = float(params["b"])

    g2 = a*2.0 + b
    if abs(g2 - 1.0) > 1e-6:
        raise ValueError(f"Constraint failed: g(2)={g2} must be 1 within 1e-6")

    # Written as "a < 0 passes" so that a NaN slope is still rejected.
    if a < 0:
        return a, b
    raise ValueError(f"Constraint failed: slope a={a} must be < 0")

def g_of_D(D: float, a: float, b: float) -> float:
    """Evaluate the linear pivot map g(D) = a*D + b."""
    scaled = a*D
    return scaled + b


# ------------------------ Kernel construction routines ----------------------

def build_prob_tridiag_kernel(D: float, a: float, b: float, N: int, eps: float = 1e-12) -> np.ndarray:
    """
    Build the analytic kernel M = D * P for an N-state chain, where P is a
    row-stochastic tri-diagonal matrix.

    g(D) = a*D + b is the raw off-diagonal weight and 2 - (that weight) is
    the raw diagonal weight; both are floored at eps. Each row is then
    normalized by the sum of its raw weights: interior rows carry two
    off-diagonal neighbours, the first and last rows carry one.

    Because every row of P sums to 1, scaling by D is what ties the
    spectral radius of M to D.

    Raises ValueError when N <= 0. For N == 1 the result is [[D]].
    """
    if N <= 0:
        raise ValueError("dim N must be >=1")
    if N == 1:
        return np.array([[D]], dtype=float)

    # Raw (un-normalized) weights, floored at eps.
    w_off = max(eps, g_of_D(D, a, b))
    w_diag = max(eps, 2.0 - w_off)

    # Interior rows: diagonal plus two neighbours.
    interior_total = w_diag + 2.0*w_off
    inner_off = w_off / interior_total
    inner_diag = w_diag / interior_total

    # First/last rows: diagonal plus a single neighbour.
    edge_total = w_diag + w_off
    edge_off = w_off / edge_total
    edge_diag = w_diag / edge_total

    P = np.zeros((N, N), dtype=float)
    k = np.arange(N)

    # Fill the three bands with the interior values ...
    P[k, k] = inner_diag
    P[k[1:], k[:-1]] = inner_off
    P[k[:-1], k[1:]] = inner_off

    # ... then overwrite the two endpoint rows.
    P[0, 0] = edge_diag
    P[0, 1] = edge_off
    P[N-1, N-1] = edge_diag
    P[N-1, N-2] = edge_off

    return D * P


# -------------------------- Diagnostics & utilities -------------------------

def spectral_radius(M: np.ndarray) -> float:
    """Return the largest eigenvalue modulus of the square matrix M."""
    moduli = np.abs(np.linalg.eigvals(M))
    return float(moduli.max())

def inf_norm(A: np.ndarray) -> float:
    """Return the infinity norm of a 2D array: its largest absolute row sum."""
    row_totals = np.abs(A).sum(axis=1)
    return float(row_totals.max())

def sha256_file(path: Path) -> str:
    """Return the hex SHA-256 digest of a file, read in 1 MiB pieces."""
    import hashlib
    block_size = 1024*1024
    digest = hashlib.sha256()
    with open(path, "rb") as stream:
        while True:
            piece = stream.read(block_size)
            if not piece:
                break
            digest.update(piece)
    return digest.hexdigest()


# ---------------------------------- Main ------------------------------------

def _kernel_filename(n: float) -> str:
    """Return the .npy file name used to store the kernel for context n."""
    label = f"{n:g}"
    # "{:g}" keeps only 6 significant digits, so distinct n values such as
    # 1234567 and 1234568 would share one file and overwrite each other.
    # Keep the short label when it round-trips exactly; otherwise use repr().
    if float(label) != n:
        label = repr(float(n))
    return f"M_n={label}.npy"


def _write_eig_checks(kernel_index: list[dict], Dmap: dict[float, float],
                      tol_eig: float, eig_csv: Path) -> tuple[float, list[float], list[float]]:
    """Write the spectral-radius CSV; return (max |rho - D|, D values, rho values)."""
    lines = ["n,D_target,lambda_max,abs_diff,pass\n"]
    max_abs_diff = 0.0
    D_list, lam_list = [], []
    for ent in kernel_index:
        n = float(ent["n"])
        M = np.load(ent["path"])
        lam = spectral_radius(M)
        Dn = float(Dmap[n])
        diff = abs(lam - Dn)
        lines.append(f"{n},{Dn},{lam},{diff},{int(diff <= tol_eig)}\n")
        max_abs_diff = max(max_abs_diff, diff)
        D_list.append(Dn)
        lam_list.append(lam)
    eig_csv.write_text("".join(lines), encoding="utf-8")
    return max_abs_diff, D_list, lam_list


def _write_inv_checks(kernel_index: list[dict], tol_inv: float,
                      inv_csv: Path) -> tuple[float, float]:
    """Write the invertibility CSV; return (worst condition number, worst residual)."""
    lines = ["n,cond_number,max_mm_inv_residual,tol,pass\n"]
    worst_cond = 0.0
    worst_resid = 0.0
    for ent in kernel_index:
        n = float(ent["n"])
        M = np.load(ent["path"])
        I = np.eye(M.shape[0])
        try:
            Minv = np.linalg.inv(M)
        except np.linalg.LinAlgError:
            cond = float('inf'); resid = float('inf'); ok = False
        else:
            cond = float(np.linalg.cond(M))
            resid = inf_norm(M.dot(Minv) - I)
            ok = (resid <= tol_inv) and np.isfinite(cond)
        # Non-finite values are summarized as 1e300 in the console report.
        worst_cond = max(worst_cond, cond if np.isfinite(cond) else 1e300)
        worst_resid = max(worst_resid, resid if np.isfinite(resid) else 1e300)
        lines.append(f"{n},{cond},{resid},{tol_inv},{int(ok)}\n")
    inv_csv.write_text("".join(lines), encoding="utf-8")
    return worst_cond, worst_resid


def _save_plots(outdir: Path, kernel_index: list[dict],
                D_list: list[float], lam_list: list[float]) -> None:
    """Best-effort plots: rho_vs_D.png and a heatmap of the middle kernel."""
    fig = None
    try:
        # ρ vs D, with y=x
        fig = plt.figure()
        plt.plot(D_list, lam_list, marker='o', linestyle='None', label='ρ(M^(i))')
        mn, mx = float(min(D_list)), float(max(D_list))
        plt.plot([mn, mx], [mn, mx], linestyle='--', label='y = x')
        plt.xlabel("D(n)"); plt.ylabel("ρ(M)")
        plt.title("Spectral Radius vs Fractal Dimension")
        plt.legend()
        plt.tight_layout()
        plt.savefig(outdir / "rho_vs_D.png", dpi=160)
    except Exception as e:
        # Plotting is optional; report and carry on with the other outputs.
        print(f"[plots] Could not create rho_vs_D.png: {e}", file=sys.stderr)
    finally:
        # Close even on failure so a half-built figure is not left open.
        if fig is not None:
            plt.close(fig)

    if not kernel_index:
        return
    fig = None
    try:
        ent0 = kernel_index[len(kernel_index)//2]
        M0 = np.load(ent0["path"])
        fig = plt.figure()
        plt.imshow(M0, aspect='equal')
        plt.colorbar()
        plt.title(f"Heatmap M(n={ent0['n']})")
        plt.tight_layout()
        plt.savefig(outdir / "M_heatmap.png", dpi=160)
    except Exception as e:
        print(f"[plots] Could not create M_heatmap.png: {e}", file=sys.stderr)
    finally:
        if fig is not None:
            plt.close(fig)


def main():
    """
    Command-line entry point.

    Reads D(n), builds one kernel M(n) per context (analytic or empirical
    mode) and saves each under <outdir>/kernels, then writes:
      - kernel_index.json       list of {n, path, shape} per kernel
      - kernel_specs.json       build settings and creation timestamp
      - kernel_eigs.csv         spectral radius vs D(n) per kernel
      - kernel_inv_checks.csv   condition number and ||M M^-1 - I||_inf
      - rho_vs_D.png, M_heatmap.png   (with --plots)
      - results_hashes.json           (with --hashes)

    Exits via SystemExit when --rates is missing in empirical mode or when
    the rates file has no entry for some n.
    """
    # Local import: only needed here, for the timezone-aware UTC timestamp.
    from datetime import timezone

    ap = argparse.ArgumentParser(description="Vol3 kernel diagnostics (build M(n), check eigs & invertibility).")
    ap.add_argument("--D", required=True, help="Path to D_of_n.{csv|json}")
    ap.add_argument("--mode", choices=["analytic","empirical"], default="analytic",
                    help="analytic: tri-diagonal P via g(D), M=D·P; empirical: M=D·P from rates")
    ap.add_argument("--rates", help="flip_rates_by_context.csv (long form n,i,j,p). Required in empirical mode.")
    ap.add_argument("--dim", type=int, default=5, help="State dimension N for analytic mode (default 5).")
    ap.add_argument("--pivot_params", default="pivot_params.json", help="Path to pivot_params.json (a,b).")
    ap.add_argument("--outdir", default="results", help="Output directory (default results).")
    ap.add_argument("--tol_eig", type=float, default=1e-6, help="Tolerance for |lambda_max - D(n)| (default 1e-6).")
    ap.add_argument("--tol_inv", type=float, default=1e-12, help="Tolerance for ||M M^{-1} - I||_inf (default 1e-12).")
    ap.add_argument("--plots", action="store_true", help="Save rho_vs_D.png and one M(n) heatmap.")
    ap.add_argument("--hashes", action="store_true", help="Write results_hashes.json with SHA256 of outputs.")
    args = ap.parse_args()

    outdir = Path(args.outdir)
    kdir   = outdir / "kernels"
    outdir.mkdir(parents=True, exist_ok=True)
    kdir.mkdir(parents=True, exist_ok=True)

    D_path = Path(args.D)
    ns, Ds = read_D_map(D_path)

    # Load the mode-specific inputs; only one of the two sets is used.
    P_by_n = None
    a = b = None
    if args.mode == "empirical":
        if not args.rates:
            raise SystemExit("--rates is required for empirical mode")
        P_by_n = read_rates_longform(Path(args.rates))
    else:
        a, b = load_pivot_params(Path(args.pivot_params))

    # Build and persist one kernel per context
    kernel_index = []
    for n, D in zip(ns, Ds):
        if args.mode == "analytic":
            M = build_prob_tridiag_kernel(float(D), a, b, args.dim)
        else:
            if n not in P_by_n:
                raise SystemExit(f"No rates found for n={n} in {args.rates}")
            M = float(D) * P_by_n[n]
        kpath = kdir / _kernel_filename(n)
        np.save(kpath, M)
        kernel_index.append({"n": float(n), "path": str(kpath.as_posix()), "shape": list(M.shape)})

    # Save an index and specs
    (outdir / "kernel_index.json").write_text(json.dumps(kernel_index, indent=2), encoding="utf-8")
    # datetime.utcnow() is deprecated; take an aware UTC time and drop the
    # tzinfo so the stored text keeps the same "...Z" layout as before.
    created_utc = datetime.now(timezone.utc).replace(tzinfo=None).isoformat() + "Z"
    specs = {
        "build_mode": args.mode,
        "mapping": "M(n)=D(n)·P(n) (row-stochastic tri-diagonal in analytic mode)",
        "dim": args.dim,
        "pivot_params": str(Path(args.pivot_params)),
        "D_source": str(D_path),
        "rates_source": str(args.rates) if args.rates else None,
        "tolerances": {"eig": args.tol_eig, "inv": args.tol_inv},
        "created_utc": created_utc,
    }
    if args.mode == "analytic":
        specs["pivot_ab"] = {"a": float(a), "b": float(b)}
    (outdir / "kernel_specs.json").write_text(json.dumps(specs, indent=2), encoding="utf-8")

    # Spectral radius checks
    eig_csv = outdir / "kernel_eigs.csv"
    Dmap = dict(zip(ns, Ds))
    max_abs_diff, D_list, lam_list = _write_eig_checks(kernel_index, Dmap, args.tol_eig, eig_csv)

    # Invertibility / conditioning checks
    inv_csv = outdir / "kernel_inv_checks.csv"
    worst_cond, worst_resid = _write_inv_checks(kernel_index, args.tol_inv, inv_csv)

    # Plots (optional)
    if args.plots:
        _save_plots(outdir, kernel_index, D_list, lam_list)

    # Optional hashes
    if args.hashes:
        manifest = {}
        for p in [eig_csv, inv_csv, outdir / "kernel_index.json", outdir / "kernel_specs.json"]:
            if p.exists():
                manifest[str(p.as_posix())] = sha256_file(p)
        (outdir / "results_hashes.json").write_text(json.dumps(manifest, indent=2), encoding="utf-8")

    print(f"[OK] Wrote {eig_csv} (max |lambda_max - D| = {max_abs_diff})")
    print(f"[OK] Wrote {inv_csv} (worst cond={worst_cond}, worst residual={worst_resid})")
    if args.hashes:
        print(f"[OK] Wrote {outdir/'results_hashes.json'}")

# Script entry point: run the pipeline only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
